#
# Inception-ResNet-V1 network components
# Details are in https://arxiv.org/pdf/1602.07261v2.pdf
#

# InceptionResNetV1: assembles the Inception-ResNet-V1 network on top of `input`.
#
#   input       : network input; the shape comments below assume 299 x 299 x 3.
#   labelDim    : output dimension of the final LinearLayer.
#   bnTimeConst : forwarded unchanged to every ConvBNReLULayer, InceptionResNet*
#                 and Reduction* block below.
#
# The value is a record whose members are the layers below; `z` is the output
# of the final linear layer (no softmax is applied in this function).
# ConvBNReLULayer, InceptionResNetA/B/C and ReductionA/B are not defined in
# this section of the file.
# NOTE(review): the ConvBNReLULayer arguments appear to be
# {output channels, kernel, stride, pad, bnTimeConst}, judging by the shape
# comments -- confirm against its definition.
InceptionResNetV1(input, labelDim, bnTimeConst) =
{
    #
    # Stem
    #
    # 299 x 299 x 3
    conv1 = ConvBNReLULayer{32, (3:3), (2:2), false, bnTimeConst}(input)
    # 149 x 149 x 32
    conv2 = ConvBNReLULayer{32, (3:3), (1:1), false, bnTimeConst}(conv1)
    # 147 x 147 x 32
    conv3 = ConvBNReLULayer{64, (3:3), (1:1), true, bnTimeConst}(conv2)
    # 147 x 147 x 64
    pool1 = MaxPoolingLayer{(3:3), stride = (2:2), pad = false}(conv3)
    # 73 x 73 x 64
    conv4 = ConvBNReLULayer{80, (1:1), (1:1), true, bnTimeConst}(pool1)
    # 73 x 73 x 80
    conv5 = ConvBNReLULayer{192, (3:3), (1:1), false, bnTimeConst}(conv4)
    # 71 x 71 x 192
    conv6 = ConvBNReLULayer{256, (3:3), (2:2), false, bnTimeConst}(conv5)
    # 35 x 35 x 256

    #
    # Inception Blocks
    #
    # Five chained InceptionResNetA blocks; each consumes the previous one.
    # 35 x 35 x 256
    inceptionResNetA_1 = InceptionResNetA{bnTimeConst}(conv6)
    inceptionResNetA_2 = InceptionResNetA{bnTimeConst}(inceptionResNetA_1)
    inceptionResNetA_3 = InceptionResNetA{bnTimeConst}(inceptionResNetA_2)
    inceptionResNetA_4 = InceptionResNetA{bnTimeConst}(inceptionResNetA_3)
    inceptionResNetA_5 = InceptionResNetA{bnTimeConst}(inceptionResNetA_4)
    # 35 x 35 x 256
    reduction_1 = ReductionA{192, 192, 256, 384, bnTimeConst}(inceptionResNetA_5)
    # Ten chained InceptionResNetB blocks.
    # 17 x 17 x 896
    inceptionResNetB_1 = InceptionResNetB{bnTimeConst}(reduction_1)
    inceptionResNetB_2 = InceptionResNetB{bnTimeConst}(inceptionResNetB_1)
    inceptionResNetB_3 = InceptionResNetB{bnTimeConst}(inceptionResNetB_2)
    inceptionResNetB_4 = InceptionResNetB{bnTimeConst}(inceptionResNetB_3)
    inceptionResNetB_5 = InceptionResNetB{bnTimeConst}(inceptionResNetB_4)
    inceptionResNetB_6 = InceptionResNetB{bnTimeConst}(inceptionResNetB_5)
    inceptionResNetB_7 = InceptionResNetB{bnTimeConst}(inceptionResNetB_6)
    inceptionResNetB_8 = InceptionResNetB{bnTimeConst}(inceptionResNetB_7)
    inceptionResNetB_9 = InceptionResNetB{bnTimeConst}(inceptionResNetB_8)
    inceptionResNetB_10 = InceptionResNetB{bnTimeConst}(inceptionResNetB_9)
    # 17 x 17 x 896
    reduction_2 = ReductionB{bnTimeConst}(inceptionResNetB_10)
    # Five chained InceptionResNetC blocks. These members are capitalized,
    # unlike the A/B members above; the names are kept as-is because record
    # member names may be referenced from outside this function.
    # 8 x 8 x 1792
    InceptionResNetC_1 = InceptionResNetC{bnTimeConst}(reduction_2)
    InceptionResNetC_2 = InceptionResNetC{bnTimeConst}(InceptionResNetC_1)
    InceptionResNetC_3 = InceptionResNetC{bnTimeConst}(InceptionResNetC_2)
    InceptionResNetC_4 = InceptionResNetC{bnTimeConst}(InceptionResNetC_3)
    InceptionResNetC_5 = InceptionResNetC{bnTimeConst}(InceptionResNetC_4)

    #
    # Prediction
    #
    # 8 x 8 x 1792
    # The 8 x 8 pooling window covers the whole 8 x 8 feature map.
    pool2 = AveragePoolingLayer{(8:8)}(InceptionResNetC_5)
    # 1 x 1 x 1792
    # No dropout rate is given here; presumably it is configured elsewhere
    # (e.g. in the training configuration) -- TODO confirm.
    drop = Dropout(pool2)
    # 1 x 1 x 1792
    # Final projection to labelDim outputs.
    z = LinearLayer{labelDim, init = 'heNormal'}(drop)
}

#
# Inception-ResNet-V1 model with normalized input. To use the function below,
# remove "ImageNet1K_mean.xml" from each reader.
#
# InceptionResNetV1Norm: applies a fixed shift-and-scale to `input` and feeds
# the result to InceptionResNetV1 with the same labelDim and bnTimeConst.
# Only the `model` member of the record below is returned (see the trailing
# `.model`); the other members are local helpers.
InceptionResNetV1Norm(input, labelDim, bnTimeConst) = 
{
    # Normalize inputs: (x - 128) * (1/128).
    # NOTE(review): this maps values in [0, 256) to [-1, 1), so it presumably
    # expects raw 8-bit pixel values -- confirm against the reader setup.
    featMean  = 128
    featScale = 1/128
    # Normalize{m,f} yields a function that subtracts m and then scales
    # element-wise by f.
    Normalize{m,f} = x => f .* (x - m)
            
    inputNorm = Normalize{featMean, featScale}(input)
    model     = InceptionResNetV1(inputNorm, labelDim, bnTimeConst)
}.model